import torch
import torch.nn.functional as F
import numpy as np
import os
from scipy.spatial import cKDTree
import trimesh
from pytorch3d.ops import knn_points
from tqdm import tqdm


class Dataset:
    """Training set of noisy query points paired with their nearest cloud points.

    The input cloud is centered and scaled into a normalized frame, randomly
    downsampled to a multiple of 60 points, and then jittered with Gaussian
    noise. Each point's noise standard deviation is proportional to that
    point's distance to its 51st nearest neighbour (itself included). For every
    jittered query, the nearest point of the downsampled cloud (and its index)
    is recorded, so that ``get_train_data`` can return aligned batches.
    """

    def __init__(self, pointcloud, part=1, scene_name='Barn', old_pc_bbx=None, old_shape_center=None, old_shape_scale=None):
        """Normalize ``pointcloud`` and precompute the query/nearest-point pairs.

        Args:
            pointcloud: array indexed as ``pointcloud[:, 0..2]``; presumably
                (N, 3) xyz coordinates -- confirm with callers.
            part, scene_name: accepted for interface compatibility; unused here.
            old_pc_bbx: optional per-axis [min, max] box to store in
                ``self.pc_bbx`` instead of the one computed from ``pointcloud``.
            old_shape_center, old_shape_scale: optional normalization parameters
                to reuse instead of computing them from ``pointcloud``. They must
                support ``.detach().cpu().numpy()`` (presumably the tensors from
                a previous Dataset -- confirm).

        Raises:
            ValueError: if ``pointcloud`` has fewer than 60 points. Truncating to
                a multiple of 60 would then leave nothing to sample.
        """
        super().__init__()
        print('Load data: Begin')
        n_points = pointcloud.shape[0]
        if n_points < 60:
            # Fail before any GPU work instead of dividing by zero further down.
            raise ValueError(f'Dataset needs at least 60 points, got {n_points}')
        self.device = torch.device('cuda')

        self.rescale = 1
        coord_min = pointcloud[:, :3].min(axis=0)
        coord_max = pointcloud[:, :3].max(axis=0)
        if old_pc_bbx is None:
            # Rows are axes (x, y, z); columns are [min, max].
            pc_bbx = np.stack([coord_min, coord_max], axis=1)
        else:
            pc_bbx = old_pc_bbx

        if old_shape_center is None:
            self.shape_center = torch.tensor((coord_max + coord_min) / 2, dtype=torch.float).to(self.device)
        else:
            self.shape_center = old_shape_center
            print('load old shape center.')
        if old_shape_scale is None:
            # Largest axis-aligned extent, so the cloud fits a unit-sized box.
            shape_scale = (coord_max - coord_min).max() / self.rescale
            self.shape_scale = torch.tensor([shape_scale], dtype=torch.float).to(self.device)
        else:
            self.shape_scale = old_shape_scale
            print('load old shape scale.')
        pointcloud = pointcloud - self.shape_center.detach().cpu().numpy()
        pointcloud = pointcloud / self.shape_scale.detach().cpu().numpy()

        point_num = n_points // 60        # points per KNN chunk
        point_num_gt = point_num * 60     # cloud truncated to a multiple of 60
        # Aim for roughly 1e6 queries in total, but always do at least 10 passes.
        query_each = max(10, 1000000 // point_num_gt)
        self.query_each = query_each

        point_idx = np.random.choice(n_points, point_num_gt, replace=False)
        self.downsample_idx = torch.tensor(point_idx, dtype=torch.long).to(self.device)
        pointcloud = pointcloud[point_idx, :]
        sigmas = self._knn_sigmas(pointcloud)

        self.pc_bbx = pc_bbx
        # Fixed box in the normalized frame (independent of pc_bbx).
        self.object_bbox_min = np.array([-0.6, -0.6, -0.6])
        self.object_bbox_max = np.array([0.6, 0.6, 0.6])
        print('bd:', self.object_bbox_min, self.object_bbox_max)

        # The reference cloud never changes, so upload it to the device once.
        reference = torch.tensor(pointcloud, dtype=torch.float, device=self.device)[None]
        samples, nearest, nearest_idx = [], [], []
        # query_each passes with small noise, then one pass with larger noise.
        for thres, repeats in ((0.25, query_each), (0.75, 1)):
            scale = self._perturbation_scale(thres, point_num_gt)
            for _ in tqdm(range(repeats)):
                queries, near_pts, near_idx = self._perturb_and_match(
                    pointcloud, sigmas, scale, point_num, reference)
                samples.append(queries)
                nearest.append(near_pts)
                nearest_idx.append(near_idx)

        # Rows of point / sample / point_idx are aligned: row i of `point` is
        # the nearest cloud point to query i, and `point_idx` is its index.
        self.point = torch.from_numpy(np.concatenate(nearest).reshape(-1, 3)).to(self.device).float()
        self.sample = torch.from_numpy(np.concatenate(samples).reshape(-1, 3)).to(self.device).float()
        self.point_gt = torch.from_numpy(np.asarray(pointcloud).reshape(-1, 3)).to(self.device).float()
        self.point_idx = torch.from_numpy(np.concatenate(nearest_idx).astype(np.int64)).to(self.device).long()
        self.sample_points_num = self.sample.shape[0] - 1

        print('NP Load data: End')

    @staticmethod
    def _knn_sigmas(pointcloud, n_chunks=100):
        """Return each point's distance to its 51st nearest neighbour (self included).

        The KD-tree is queried in chunks to bound memory. The chunk count is
        clamped so that no chunk is empty. Chunking does not affect the result.
        """
        tree = cKDTree(pointcloud)
        n_chunks = max(1, min(n_chunks, len(pointcloud)))
        return np.concatenate([tree.query(chunk, 51)[0][:, -1]
                               for chunk in np.array_split(pointcloud, n_chunks, axis=0)])

    @staticmethod
    def _perturbation_scale(thres, point_num_gt):
        """Return the noise multiplier: ``thres``, grown by sqrt(n / 20000) above 20000 points."""
        return max(thres, thres * np.sqrt(point_num_gt / 20000))

    def _perturb_and_match(self, pointcloud, sigmas, scale, point_num, reference):
        """Jitter every cloud point once and find each jittered point's nearest cloud point.

        Returns ``(queries, nearest_points, nearest_idx)``. Rows are aligned
        with ``pointcloud``; indices refer to rows of ``pointcloud``.
        """
        queries = pointcloud + scale * sigmas[:, None] * np.random.normal(0.0, 1.0, size=pointcloud.shape)
        idx_chunks = []
        # KNN is run per chunk of point_num queries, presumably to bound GPU memory.
        for chunk in queries.reshape(-1, point_num, 3):
            knns = knn_points(torch.tensor(chunk, dtype=torch.float, device=self.device)[None],
                              reference, K=1)
            idx_chunks.append(knns.idx[0][:, 0].cpu().numpy())
        near_idx = np.concatenate(idx_chunks)
        return queries, pointcloud[near_idx], near_idx

    def get_train_data(self, batch_size):
        """Sample a random training batch.

        Returns ``(points, sample, point_gt, points_idx)``:
        - ``batch_size`` nearest cloud points;
        - the matching noisy query points;
        - the full downsampled cloud;
        - the nearest points' indices into that cloud.

        Every chosen index shares one random residue modulo 10. This shrinks the
        no-replacement draw to ``sample_points_num // 10`` candidates, which is
        faster than drawing over all samples.
        """
        index_coarse = np.random.choice(10, 1)
        index_fine = np.random.choice(self.sample_points_num // 10, batch_size, replace=False)
        index = index_fine * 10 + index_coarse
        points = self.point[index]
        sample = self.sample[index]
        points_idx = self.point_idx[index]
        return points, sample, self.point_gt, points_idx
